import json

import numpy as np
import torch


from Causal_MNIST_Images.GroundTruth.CausalGraph_Mnist import getdoKey
from Causal_MNIST_Images.GroundTruth.Synthetic_Distribution_Mnist import get_synthetic_dist, get_intv_dist, \
    get_bayesian_network
from Causal_MNIST_Images.mnistControllerModel import get_generated_labels
from Causal_MNIST_Images.DigitImageGeneration.mnist_image_generation import plot_dataset_digits, plot_trained_digits
from Causal_Partial_Mnist.RejectionSampling_Optimized import rejection_sampling_optimized
from ModularUtils.ControllerConstants import map_dictfill_to_discrete
from ModularUtils.FunctionsConstant import asKey
from ModularUtils.FunctionsDistribution import match_with_true_dist


def get_observational_loss(Exp, obs_vars, label_generators, tvd_diff, kl_diff):
    """Score the generators on the purely observational query over ``obs_vars``.

    Draws ``Exp.Synthetic_Sample_Size`` hard samples for ``obs_vars`` with no
    intervention applied, compares them against the ground-truth distribution
    returned by ``get_intv_dist`` and appends the resulting TVD / KL values to
    the lists stored under the query key.

    Args:
        Exp: experiment configuration object (reads ``Synthetic_Sample_Size``).
        obs_vars: names of the variables evaluated jointly.
        label_generators: generator models, forwarded to ``get_generated_labels``.
        tvd_diff: mapping query key -> list of TVD scores; appended in place.
        kl_diff: mapping query key -> list of KL scores; appended in place.

    Returns:
        The same ``(tvd_diff, kl_diff)`` objects after appending.
    """
    sampled = get_generated_labels(
        Exp, label_generators, {}, {}, {}, obs_vars,
        Exp.Synthetic_Sample_Size, hard=True,
    )
    sampled_discrete = map_dictfill_to_discrete(Exp, sampled, obs_vars)

    # No intervention (empty list) -> observational query. The key is also
    # used as the saved-SCM file name.
    key = getdoKey(obs_vars, [])
    reference = get_intv_dist(Exp, obs_vars, [], key)

    tvd, kl = match_with_true_dist(
        Exp, obs_vars, sampled_discrete, reference, "feature", doPrint=False
    )

    for store, score in ((tvd_diff, tvd), (kl_diff, kl)):
        store[key].append(score)

    return tvd_diff, kl_diff


def get_expected_loss_interventions(Exp, cur_mechs, label_generators, tvd_diff, kl_diff):
    """Record the expected interventional TVD / KL for queries touching ``cur_mechs``.

    For every query in ``Exp.interv_queries`` whose ``"obs"`` variables overlap
    ``cur_mechs``, each intervention assignment in ``query["intervs"]`` is
    evaluated:

    - labels are sampled from the generators under that intervention;
    - they are compared with the ground-truth interventional distribution.

    The per-assignment scores are summed with weights taken from the
    observational distribution of the intervened values. The sums are
    appended, rounded to 4 decimals, under ``query["expr"]``.

    Args:
        Exp: experiment configuration; reads ``interv_queries`` and
            ``Synthetic_Sample_Size``.
        cur_mechs: names of the mechanisms currently being trained. Queries
            whose ``"obs"`` variables are disjoint from these are skipped.
        label_generators: generator models, forwarded to ``get_generated_labels``.
        tvd_diff: mapping ``query["expr"]`` -> list of TVD scores; appended in place.
        kl_diff: mapping ``query["expr"]`` -> list of KL scores; appended in place.

    Returns:
        The same ``(tvd_diff, kl_diff)`` objects after appending.
    """
    feat = "feature"

    for query in Exp.interv_queries:
        # Only evaluate queries that involve at least one current mechanism.
        if set(query["obs"]).isdisjoint(cur_mechs):
            continue

        # The first assignment determines which variables are intervened on.
        # This assumes every assignment in the query uses the same variables.
        compare_Var = list(query["intervs"][0].keys())
        query_str = getdoKey(compare_Var, {})  # getting the scm saving file name
        # Observational distribution of the intervened variables; used as weights.
        obs_dist = get_intv_dist(Exp, compare_Var, {}, query_str)

        tvd_sum = 0
        kl_sum = 0
        for intv_key in query["intervs"]:
            query_string = getdoKey(query["obs"], intv_key)
            true_dist = get_intv_dist(Exp, query["obs"], intv_key, query_string)

            # NOTE(review): the other get_generated_labels call sites in this
            # file pass hard=True; confirm the default is intended here.
            generated_labels_dict = get_generated_labels(
                Exp, label_generators, {}, {}, intv_key, query["obs"],
                Exp.Synthetic_Sample_Size,
            )
            generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, query["obs"])
            obs_tvd, obs_kl = match_with_true_dist(
                Exp, query["obs"], generated_labels_full, true_dist, feat, doPrint=False
            )  # get it from scm

            # Also print Pinsker's bound sqrt(KL/2) for comparison.
            # The bound holds when KL is measured in nats.
            print(f'{intv_key}: tvd:{obs_tvd}, kl:{obs_kl} and tvd<={np.sqrt(0.5 * obs_kl)}')

            # The value tuple must follow the same key order as obs_dist's
            # index. Presumably it matches compare_Var; TODO confirm.
            weight = obs_dist[tuple(intv_key.values())]
            tvd_sum += obs_tvd * weight
            kl_sum += obs_kl * weight

        print(f'--->Average tvd:{tvd_sum}, kl:{kl_sum} and tvd<={np.sqrt(0.5 * kl_sum)}')
        tvd_diff[query["expr"]].append(round(tvd_sum, 4))
        kl_diff[query["expr"]].append(round(kl_sum, 4))

    return tvd_diff, kl_diff


def _plot_generated_images(Exp, cur_mechs, label_generators, minibatch=3):
    """Sample ``minibatch`` observational (labels, image) pairs and plot each image."""
    # Assumes the last entry of cur_mechs is the image mechanism and the
    # preceding entries are its label parents. TODO confirm.
    label_vars = cur_mechs[0:-1]
    image_label = Exp.image_labels[0]

    generated_labels_dict = get_generated_labels(
        Exp, label_generators, {}, {}, {}, label_vars + [image_label], minibatch, hard=True
    )
    # Separate the image output from the discrete label outputs.
    generated_image = generated_labels_dict.pop(image_label)
    generated_labels_ig = map_dictfill_to_discrete(Exp, generated_labels_dict, label_vars)

    for grow, genimg in zip(generated_labels_ig, generated_image):
        print("gen", grow)
        # Move axis 0 (assumed to be channels) last, as the plotting helper expects.
        genimg = genimg.permute(1, 2, 0).detach().cpu().numpy()
        plot_trained_digits(1, 1, [genimg], f'Real {grow}')


def evaluate_after_epochs(Exp, cur_mechs, label_generators, dataset_dict, tvd_diff, kl_diff):
    """Evaluate the label generators for ``cur_mechs`` and plot a few generated images.

    Runs in two phases:

    - Puts every generator in eval mode and works under ``torch.no_grad()``.
    - For each intervention in ``Exp.Data_intervs``, samples hard labels for
      the non-image mechanisms of ``cur_mechs`` that the intervention does not
      fix. It computes TVD / KL against the ground-truth interventional
      distribution and plots a small batch of generated images.

    The generators are always switched back to train mode, even if evaluation
    raises.

    NOTE: the computed scores are currently not recorded (see TODOs below), so
    ``tvd_diff`` and ``kl_diff`` are returned unmodified.

    Args:
        Exp: experiment configuration (reads ``Data_intervs``, ``image_labels``,
            ``label_names``, ``Synthetic_Sample_Size``, ``DEVICE``).
        cur_mechs: names of the mechanisms being trained.
        label_generators: mapping name -> ``torch.nn.Module``-like generator.
        dataset_dict: mapping ``asKey(intervention)`` -> label tensor. Columns
            are indexed by position in ``Exp.label_names``.
        tvd_diff: mapping query key -> list of TVD scores.
        kl_diff: mapping query key -> list of KL scores.

    Returns:
        ``(tvd_diff, kl_diff)``.
    """
    for gen in label_generators:
        label_generators[gen].eval()

    try:
        with torch.no_grad():
            feat = "feature"

            for key in Exp.Data_intervs:
                intv_key = asKey(key)
                intervened = dict(intv_key)

                # Non-image mechanisms that this intervention does not fix.
                compare_Var = [
                    lb for lb in cur_mechs
                    if lb not in Exp.image_labels and lb not in intervened
                ]
                obs_indices = [Exp.label_names.index(lb) for lb in compare_Var]

                # NOTE(review): only the disabled result-saving code needs these
                # real labels. .long() keeps the tensor on its device; the old
                # .type(torch.LongTensor) went through the CPU.
                current_real_label = []
                if intv_key in dataset_dict:
                    current_real_label = (
                        dataset_dict[intv_key][:, obs_indices]
                        .long()
                        .view(-1, len(obs_indices))
                        .to(Exp.DEVICE)
                    )

                generated_labels_dict = get_generated_labels(
                    Exp, label_generators, {}, {}, dict(intv_key), compare_Var,
                    Exp.Synthetic_Sample_Size, hard=True,
                )
                generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)

                query_str = getdoKey(compare_Var, dict(intv_key))
                true_dist_dict = get_intv_dist(Exp, compare_Var, dict(intv_key), query_str)

                obs_tvd, obs_kl = match_with_true_dist(
                    Exp, compare_Var, generated_labels_full, true_dist_dict, feat, doPrint=False
                )
                # TODO: append round(obs_tvd, 4) / round(obs_kl, 4) to
                # tvd_diff / kl_diff once the query-key format is fixed.

                # NOTE(review): this image sample ignores intv_key (it is
                # observational), so it is redrawn and replotted once per
                # intervention. Consider moving it out of the loop.
                _plot_generated_images(Exp, cur_mechs, label_generators, minibatch=3)

            # TODO: the calls to get_observational_loss,
            # get_expected_loss_interventions and save_results are disabled.
    finally:
        # Always restore training mode, even if evaluation raised.
        for gen in label_generators:
            label_generators[gen].train()

    return tvd_diff, kl_diff


